{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tune neural networks with lexicographic preference across objectives\n",
    "This example is to tune neural networks model with two objectives \"error_rate\", \"flops\" on FashionMnist dataset. \n",
    "\n",
    "**Requirements.** This notebook requires:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install torch torchvision flaml[blendsearch,ray] thop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import thop\n",
    "import torch.nn as nn\n",
    "from flaml import tune\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "DEVICE = torch.device(\"cpu\")\n",
    "BATCHSIZE = 128\n",
    "N_TRAIN_EXAMPLES = BATCHSIZE * 30\n",
    "N_VALID_EXAMPLES = BATCHSIZE * 10\n",
    "data_dir = os.path.abspath(\"data\")\n",
    "\n",
    "train_dataset = torchvision.datasets.FashionMNIST(\n",
    "    data_dir,\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=torchvision.transforms.ToTensor(),\n",
    ")\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),\n",
    "    batch_size=BATCHSIZE,\n",
    "    shuffle=True,\n",
    ")\n",
    "\n",
    "val_dataset = torchvision.datasets.FashionMNIST(\n",
    "    data_dir, train=False, transform=torchvision.transforms.ToTensor()\n",
    ")\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),\n",
    "    batch_size=BATCHSIZE,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specify the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def define_model(configuration):\n",
    "    n_layers = configuration[\"n_layers\"]\n",
    "    layers = []\n",
    "    in_features = 28 * 28\n",
    "    for i in range(n_layers):\n",
    "        out_features = configuration[\"n_units_l{}\".format(i)]\n",
    "        layers.append(nn.Linear(in_features, out_features))\n",
    "        layers.append(nn.ReLU())\n",
    "        p = configuration[\"dropout_{}\".format(i)]\n",
    "        layers.append(nn.Dropout(p))\n",
    "        in_features = out_features\n",
    "    layers.append(nn.Linear(in_features, 10))\n",
    "    layers.append(nn.LogSoftmax(dim=1))\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, optimizer, train_loader):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)\n",
    "        optimizer.zero_grad()\n",
    "        F.nll_loss(model(data), target).backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_model(model, valid_loader):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, target) in enumerate(valid_loader):\n",
    "            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)\n",
    "            pred = model(data).argmax(dim=1, keepdim=True)\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    accuracy = correct / N_VALID_EXAMPLES\n",
    "    flops, params = thop.profile(\n",
    "        model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),), verbose=False\n",
    "    )\n",
    "    return np.log2(flops), 1 - accuracy, params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_function(configuration):\n",
    "    model = define_model(configuration).to(DEVICE)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), configuration[\"lr\"])\n",
    "    n_epoch = configuration[\"n_epoch\"]\n",
    "    for epoch in range(n_epoch):\n",
    "        train_model(model, optimizer, train_loader)\n",
    "    flops, error_rate, params = eval_model(model, val_loader)\n",
    "    return {\"error_rate\": error_rate, \"flops\": flops, \"params\": params}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lexicographic information across objectives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lexico_objectives = {}\n",
    "lexico_objectives[\"metrics\"] = [\"error_rate\", \"flops\"]\n",
    "lexico_objectives[\"tolerances\"] = {\"error_rate\": 0.02, \"flops\": 0.0}\n",
    "lexico_objectives[\"targets\"] = {\"error_rate\": 0.0, \"flops\": 0.0}\n",
    "lexico_objectives[\"modes\"] = [\"min\", \"min\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Search space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "search_space = {\n",
    "    \"n_layers\": tune.randint(lower=1, upper=3),\n",
    "    \"n_units_l0\": tune.randint(lower=4, upper=128),\n",
    "    \"n_units_l1\": tune.randint(lower=4, upper=128),\n",
    "    \"n_units_l2\": tune.randint(lower=4, upper=128),\n",
    "    \"dropout_0\": tune.uniform(lower=0.2, upper=0.5),\n",
    "    \"dropout_1\": tune.uniform(lower=0.2, upper=0.5),\n",
    "    \"dropout_2\": tune.uniform(lower=0.2, upper=0.5),\n",
    "    \"lr\": tune.loguniform(lower=1e-5, upper=1e-1),\n",
    "    \"n_epoch\": tune.randint(lower=1, upper=20),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Launch the tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "low_cost_partial_config = {\n",
    "    \"n_layers\": 1,\n",
    "    \"n_units_l0\": 4,\n",
    "    \"n_units_l1\": 4,\n",
    "    \"n_units_l2\": 4,\n",
    "    \"n_epoch\": 1,\n",
    "}\n",
    "\n",
    "analysis = tune.run(\n",
    "    evaluate_function,\n",
    "    num_samples=-1,\n",
    "    time_budget_s=100,\n",
    "    config=search_space,\n",
    "    use_ray=False,\n",
    "    lexico_objectives=lexico_objectives,\n",
    "    low_cost_partial_config=low_cost_partial_config,\n",
    ")\n",
    "result = analysis.best_result\n",
    "print(result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.14 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.14"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
